import numpy as np
import timeautoencoder as tae
import timediffusion_cond as tdf
import DP_Sliding as dp
import pandas as pd
import torch
import os
import time
import process_edited as pce
import correl as correl
import Metrics as mt

# Module-wide compute device: CUDA when PyTorch reports it as available,
# otherwise CPU. Both functions below move their tensors onto this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def C_TimeAutoDiff(real_df, cond_real_df, response_train, cond_train, time_info_train, VAE_training, diff_training, lat_dim):
    """Train the auto-encoder, then the conditional diffusion model on its latents.

    The two stages run back to back. All hyper-parameters except the epoch
    counts and ``lat_dim`` are fixed inside this function.

    Args:
        real_df: Forwarded unchanged as the first argument of
            ``tae.train_autoencoder``.
        cond_real_df: Forwarded unchanged as the second argument of
            ``tdf.train_diffusion``.
        response_train: Forwarded unchanged as the second argument of
            ``tae.train_autoencoder``.
        cond_train: Conditioning input for diffusion training; must expose
            ``.to(device)`` (presumably a ``torch.Tensor`` -- confirm with
            the caller).
        time_info_train: Time information for diffusion training; must
            expose ``.to(device)``.
        VAE_training: Number of epochs for the auto-encoder stage.
        diff_training: Number of epochs for the diffusion stage.
        lat_dim: Forwarded to ``tae.train_autoencoder`` (presumably the
            latent dimensionality -- confirm against ``timeautoencoder``).

    Returns:
        tuple: ``(ds, diff, latent_features)`` where ``ds`` is whatever
        ``tae.train_autoencoder`` returned, ``diff`` is whatever
        ``tdf.train_diffusion`` returned, and ``latent_features`` is
        ``ds[1]``.
    """
    ###### Auto-encoder Training ######
    ae_epochs = VAE_training
    weight_decay = 1e-6
    lr = 2e-4
    hidden_size = 200
    ae_num_layers = 2
    batch_size = 100
    channels = 64
    min_beta = 1e-5
    max_beta = 0.1
    emb_dim = 128
    time_dim = 8
    threshold = 1

    # NOTE: arguments are positional on purpose -- the keyword names of the
    # trainer are not visible from this module, so the order must be kept.
    ds = tae.train_autoencoder(
        real_df, response_train, channels, hidden_size, ae_num_layers, lr,
        weight_decay, ae_epochs, batch_size, threshold, min_beta, max_beta,
        emb_dim, time_dim, lat_dim, device,
    )

    # The second element of the trainer's result feeds the diffusion stage.
    latent_features = ds[1]

    ###### Diffusion Training ######
    diff_epochs = diff_training
    hidden_dim = 200
    diff_num_layers = 2
    diffusion_steps = 100
    # One "class" per entry along the first axis of the latent features.
    num_classes = len(latent_features)

    diff = tdf.train_diffusion(
        latent_features, cond_real_df, cond_train.to(device),
        time_info_train.to(device), hidden_dim, diff_num_layers,
        diffusion_steps, diff_epochs, num_classes,
    )

    return (ds, diff, latent_features)

def cond_sampling(real_df, ds, diff, cond_test, time_info_test, Batch_size, Seq_len, Lat_dim):
    """Draw conditional samples from the diffusion model and decode them.

    Args:
        real_df: Forwarded unchanged as the first argument of
            ``pce.convert_to_tensor``.
        ds: Result of the auto-encoder training step; ``ds[0]`` must expose
            a ``decoder`` callable.
        diff: Trained diffusion model, forwarded to ``tdf.sample``.
        cond_test: Conditioning input for sampling; must expose
            ``.to(device)``.
        time_info_test: Time information for sampling; must expose
            ``.to(device)``.
        Batch_size: Number of sequences to generate.
        Seq_len: Number of time steps per sequence.
        Lat_dim: Forwarded to ``tdf.sample`` (presumably the latent
            dimensionality -- confirm against ``timediffusion_cond``).

    Returns:
        Whatever ``pce.convert_to_tensor`` produces from the decoder output.
    """
    # Evenly spaced time grid on [0, 1], shaped (1, Seq_len, 1) and moved to
    # the compute device, then tiled to (Batch_size, Seq_len, 1).
    unit_grid = torch.linspace(0, 1, Seq_len)
    t_grid = unit_grid.view(1, -1, 1).to(device)
    batched_grid = t_grid.repeat(Batch_size, 1, 1)

    cond_on_device = cond_test.to(device)
    time_info_on_device = time_info_test.to(device)
    latent_samples = tdf.sample(
        batched_grid, Batch_size, Seq_len, Lat_dim, diff,
        cond_on_device, time_info_on_device,
    )

    # Post-process the generated data: decode the sampled latents, then let
    # the project helper reshape/convert the decoder output.
    # NOTE(review): decoding runs with autograd enabled; wrapping it in
    # torch.no_grad() may save memory -- confirm convert_to_tensor does not
    # rely on gradients first.
    decoder = ds[0].decoder
    gen_output = decoder(latent_samples.to(device))
    synth_data = pce.convert_to_tensor(real_df, gen_output, 1, Batch_size, Seq_len)

    return synth_data